import org.apache.spark.rdd.RDD
import org.apache.spark.{SparkContext, SparkConf}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.ml.recommendation._
import scala.util.Random
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
import org.apache.spark.sql.functions._
import java.util.Properties
import org.apache.spark.sql.types._
import org.apache.spark.sql.Row
import scala.collection.mutable.ArrayBuffer
import scala.util.Random
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.ml.recommendation.{ALS, ALSModel}
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
import org.apache.spark.sql.functions._


/*  观察数据分布情况 */




/**
*默认情况下，RDD 为每个HDFS 块生成一个分区，将HDFS 块大小设为典型的128M 或 64M。
*由于此文件大小为 400M 左右，所以文件被拆为 3 个或 6 个分区。这通常没什么问题，
*但由于相比简单文本处理，ALS 这类机器学习算法要消耗更多的计算资源，因此减小数据块大小以增加分区个数会更好。
*减小数据块能使Spark 处理任务的同时使用的处理器核数更多。
*可以为textFile方法设置第二个参数，用这个参数指定一个不同于默认值的分区数，这样就可以将分区数设置的更大些。比如，可以考虑将这个参数设为集群处理器总核数。
*/


//读入数据
var rootDir = "hdfs://localhost:9000/dbtaobao/dataset/"
val rawUserArtistData = sc.textFile(rootDir + "user_artist_data.txt",4)
val rawArtistData = sc.textFile(rootDir + "artist_data.txt")
val rawArtistAlias = sc.textFile(rootDir + "artist_alias.txt")


//数据清理
//文件每行包含一个用户ID、一个艺术家ID和播放次数，用空格隔开。将每行的前两个值解析为整数：用户ID\艺术家ID，
//然后将其转换为包含列user和artist的DataFrame，通过agg简单计算出两列的最大值和最小值
//最大的用户ID和艺术家ID分别为 2443548 和 10794401 这两个远小于Int.Max，所以可以对其进行转换为整型。

/*
def GetUserArtustDF( rawUserArtistData: RDD[String]): DataFrame ={
        val userArtistDF = rawUserArtistData.map{ line =>
        val Array(user, artist, _*) = line.split(' ')
        (user.toInt, artist.toInt)
        }.toDF("user","artist")
}
val userArtistDF = GetUserArtustDF( rawUserArtistData)
*/

def buildCounts(
      rawUserArtistData: Dataset[String],
      bArtistAlias: Broadcast[Map[Int,Int]]): DataFrame = {
    rawUserArtistData.map { line =>
      val Array(userID, artistID, count) = line.split(' ').map(_.toInt)
      val finalArtistID = bArtistAlias.value.getOrElse(artistID, artistID)
      (userID, finalArtistID, count)
    }.toDF("user", "artist", "count")
  }



val userArtistDF = rawUserArtistData.map{ line =>
val Array(user, artist, _*) = line.split(' ')
(user.toInt, artist.toInt)
}.toDF("user","artist")


//查看数据有无超出int范围
// 得到的统计数据，可以发现用户 id 最高为2443548，艺术家 id 最高为 10794401，都远小于 ALS 的数值型限制。
userArtistDF.agg(min("user"),max("user"),min("artist"),max("artist")).show()


//
//
//有些艺术家名字和ID没有按 \t 分割，错误处理就是放弃这些数据
//span 碰到第一个不满足条件的开始划分， 少量的行转换不成功， 数据质量问题
def GetArtistByID( rawArtistData: RDD[String]): DataFrame = {
           rawArtistData.flatMap{ line =>
           val (id, name) = line.span(_!='\t')
           if(name.isEmpty){
           None
           }else{
           try{
           Some((id.toInt,name.trim))
           }catch {
                case _:NumberFormatException => None
              }
           }
        }.toDF("id","name")
}
val  artistByID = GetArtistByID( rawArtistData)
/*
                val artistByID = rawArtistData.flatMap{ line =>
                    val (id,name) = line.span(_!= '\t')
                if (name.isEmpty){
                        None
                        } else {
                        try{
                            Some((id.toInt,name.trim))
                        }catch{
                            case _:NumberFormatException => None
                        }
                    }
                }.toDF("id","name")
*/
artistByID.show()

 /**
   * 通过文件 artist_alias.txt 得到所有艺术家的别名
   * 文件不大，每一行按照 \t 分割包含一个拼错的名字ID 还有一个正确的名字ID
   * 一些行没有第一个拼错的名字ID，直接跳过
   * @param sc Spark上下文
   * @return
   */

def GetArtistAlias( rawArtistAlias: RDD[String]):Map[Int,Int] = {
      rawArtistAlias.flatMap{ line =>
      val Array(artist,alias) = line.split("\t")
      if (artist.isEmpty){
      None
      }else{
       try{ Some((artist.toInt,alias.toInt))
          } catch{
       case _: NumberFormatException => None
              }
      }
      }.collect().toMap
}
val artistAlias =    GetArtistAlias( rawArtistAlias)
/*
                        val artistAlias = rawArtistAlias.flatMap{ line =>
                              val Array(artist,alias) = line.split("\t")
                              if (artist.isEmpty){
                              None
                              }else{
                              Some((artist.toInt,alias.toInt))
                              }
                              }.collect().toMap
*/
artistAlias.take(10)

//从第一条记录我们可以看到：6803336 映射为 1000010

//我们在artistByID里查询这两条记录：

 artistByID.filter($"id" === 6803336).show()//res9: String = Aerosmith (unplugged)
 artistByID.filter($"id" === 1000010).show()  //res10: String = Aerosmith

//显然，这条记录将 Aerosmith (unplugged) 映射为 Aerosmith

val bArtistAlias = sc.broadcast(artistAlias)


/*
如果艺术家ID存在一个不同的正规ID,我们要用artist_alias.txt将所有的艺术家ID转换为正规ID.
因为Spark集群中的每个executor都要使用到artistAlias这个变量，这时可以为其创建一个广播变量，
取名为bArtistAlias，使用广播变量时，Spark对集群中每个 executor只发送一个副本，
并且在内存中也只保存了一个副本，能够节省大量的网络流量和内存。
调用cache()让Spark在DataFrame计算好之后将其暂时存储在集群的内存里，
因为ALS算法是迭代的，因此缓存是非常有益和节省计算时间。一般默认数据是序列化为字节的形式存储在内存中的，而不是对象的形式。
结果会有所不同，原因是 最终的模型取决于初始特征向量，而这些初始特征向量是随机选择的，
ML的ALS模型和其他组件默认设置了固定的随机种子，每次都会做出相同的随机选择，通过setSeed(Random.nextLong())就可以设置一个真正的随机种子。
*/
//将数据存储在内存中
def GetTrainDF( rawUserArtistData: RDD[String],
      bArtistAlias: Broadcast[Map[Int,Int]]): DataFrame = {
      rawUserArtistData.map{ line =>
      val Array(userId,artistId,count) = line.split(' ').map(_.toInt)
      val finalArtistID = bArtistAlias.value.getOrElse(artistId, artistId)
      (userId, finalArtistID, count)
      }.toDF("user", "artist", "count")
}

//val trainData = GetTrainDF( rawUserArtistData,bArtistAlias).persist(StorageLevel.MEMORY_ONLY_SER)


val trainData = rawUserArtistData.map{
        line =>
            val Array(userId, artistId, count) = line.split(' ').map(_.toInt) 
            val finalArtistID = bArtistAlias.value.getOrElse(artistId, artistId)
            (userId, finalArtistID, count)
        // Rating(userId, finalArtistID, count)
        }.toDF("user", "artist", "count").persist(StorageLevel.MEMORY_ONLY_SER)


val Array(trainData1, testData) = trainData.randomSplit(Array(0.8, 0.2))

//查看数据
trainData.show()

//广播变量主要用于在迭代中一直需要被访问的只读变量。它将此变量缓存在每个executor 里，以减少集群网络传输消耗
//Spark 执行一个阶段（stage）时，会为待执行函数建立闭包，也就是该阶段所有任务所需信息的二进制形式。这
//个闭包包括驱动程序里函数引用的所有数据结构。Spark 把这个闭包发送到集群的每个executor 上。
//当许多任务需要访问同一个（不可变的）数据结构时，我们应该使用广播变量。它对任务闭包的常规处理进行扩展，是我们能够：
//在每个 executor 上将数据缓存为原始的 Java 对象，这样就不用为每个人物执行反序列化在多个作业和阶段之间缓存数据
//在函数最后，我们调用了 cache() 以指示 Spark 在 RDD 计算好后将其暂时存储在集群的内存里。
//这样是有益的，因为 ALS 算法是 迭代的，通常情况下至少要访问该数据 10 次以上。如果不调用 cache()，那么每次要用到 RDD 时都需要从原始数据中重新计算。

val modelnew =  ALSModel.load("/usr/local/spark/Model/modelnew")

/*val modelnew = new ALS().
    setSeed(Random.nextLong()). //设置随机种子
    setImplicitPrefs(true).
    setRank(5).
    setRegParam(0.1).
    setAlpha(1.0).
    setMaxIter(5).
    setUserCol("user").
    setItemCol("artist").
    setRatingCol("count").
    setPredictionCol("prediction").
    fit(trainData1)
  */
 modelnew.userFactors.select("features").show(truncate = false)

 def makeRecommendations( model: ALSModel,userID: Int, howMany: Int): DataFrame ={
      val toRecommend = model.itemFactors.
      select($"id".as("artist")).
      withColumn("user",lit(userID))//选择所有艺术家ID与对应的目标用户ID

      model.transform(toRecommend).
      select("artist","prediction").
      orderBy($"prediction".desc).
      limit(howMany)   //对所有艺术家评分，并返回其中分值最高的
      }

val userID_10 = testData.select("user").as[Int].take(4).distinct

userID_10.map{user =>
     val recommend = makeRecommendations(modelnew, user, 10)

     val recommendedArtistIDs = recommend.select("artist").as[Int].collect()
     val ss = artistByID.filter($"id" isin (recommendedArtistIDs:_*)).toDF("artist","name")
     ss.join(recommend,"artist").show()
     val existingArtistIDs = testData.filter($"user" === user). //找到用户对应的行
     select("artist").as[Int].collect() //收集艺术家ID的整型集合
     artistByID.filter($"id" isin (existingArtistIDs:_*)).show()
     println()
}

def Artistpredict(userID:Int):DataFrame={
     val recommend = makeRecommendations(modelnew, userID, 10)
     val recommendedArtistIDs = recommend.select("artist").as[Int].collect()
     val ss = artistByID.filter($"id" isin (recommendedArtistIDs:_*)).toDF("artist","name")
     ss.join(recommend,"artist").toDF("artist","name","prediction")
}

//查看一下 5个用户的听歌情况
//      val recommendedArtistIds  = toRecommend.select("artist").as[Int].collect()
//      artistByID.filter($"id" isin (recommendedArtistIds:_*)).select("name")
      //(toRecommend,topReconmend,artists)

//userID = 1000112

//val topRecommendations = makeRecommendations(modelnew,userID,5)
//topRecommendations.show()

//val existingArtistIDs = trainData.filter($"user" === userID). //找到用户1000112对应的行
//select("artist").as[Int].collect() //收集艺术家ID的整型集合
//artistByID.filter($"id" isin (existingArtistIDs:_*)).show() //过滤艺术家；_*变长参数语法

//val topRecommendations = makeRecommendations(modelnew, userID, 5)
//topRecommendations.show()
//val recommendedArtistIDs = topRecommendations.select("artist").as[Int].collect()
//artistByID.filter($"id" isin (recommendedArtistIDs:_*)).show()

 modelnew.userFactors.unpersist()


//模型评估

val Array(trainData1, testData) = trainData.randomSplit(Array(0.8, 0.2))
trainData1.cache()
testData.cache()

val allArtistIDs =trainData.select("artist").as[Int].distinct().collect()

val bAllArtistIDs = sc.broadcast(allArtistIDs)
//auc 进行评价

def auc( posiData: DataFrame,bAllArtistIDs: Broadcast[Array[Int]], 
                   predicFunc:(DataFrame=> DataFrame)): Double={
      
// 进行预测
      val  positivePredictions = predicFunc(posiData.select("user","artist"))
                      .withColumnRenamed("prediction","positivePrediction")
//从不在正确user中挑选随机的负面集
    val negativeData = posiData.select("user", "artist").as[(Int,Int)].
                                   groupByKey { case (user, _) => user }.
                                   flatMapGroups { case (userID, userIDAndPosArtistIDs) =>
                                   val random = new Random()
                                  val posItemIDSet = userIDAndPosArtistIDs.map { case (_, artist) => artist }.toSet
                                  val negative = new ArrayBuffer[Int]()
                                  val allArtistIDs = bAllArtistIDs.value
                                  
                                  var i = 0
           while (i < allArtistIDs.length && negative.size < posItemIDSet.size) {
           val artistID = allArtistIDs(random.nextInt(allArtistIDs.length))
                                 // Only add new distinct IDs
                                    if (!posItemIDSet.contains(artistID)) {
                                                negative += artistID
                                      }
                                        i += 1
                                        }
                                  negative.take(10)
                                  negative.map(artistID => (userID, artistID))
                               }.toDF("user", "artist")

//预测负面集
      val negativePredictions = predicFunc(negativeData).
      withColumnRenamed("prediction", "negativePrediction")

//通过用户将 posi和nega链接
    val joinedPredictions = positivePredictions.join(negativePredictions, "user").
      select("user", "positivePrediction", "negativePrediction").cache()
      joinedPredictions.show()

 // Count the number of pairs per user
       val allCounts = joinedPredictions.
        groupBy("user").agg(count(lit("1")).as("total")).
        select("user", "total")
        allCounts.show()

    // Count the number of correctly ordered pairs per user
        val correctCounts = joinedPredictions.
        filter($"positivePrediction" > $"negativePrediction").
        groupBy("user").agg(count("user").as("correct")).
        select("user", "correct")
        correctCounts.show()

        allCounts.join(correctCounts, "user").
        select($"user", ($"correct" / $"total").as("auc")).
        agg(mean("auc")).show()

    // Combine these, compute their ratio, and average over all users
        val meanAUC = allCounts.join(correctCounts, "user").
        select($"user", ($"correct" / $"total").as("auc")).
        agg(mean("auc")).
        as[Double].first()

         joinedPredictions.unpersist()

         meanAUC
  }


val auc_score = auc(testData, bAllArtistIDs, modelnew.transform)

def predictMostListened(train: DataFrame)(allData: DataFrame): DataFrame = {
    val listenCounts = train.groupBy("artist").
    agg(sum("count").as("prediction")).
    select("artist", "prediction")
    allData.
    join(listenCounts, Seq("artist"), "left_outer").
    select("user", "artist", "prediction")
}

val mostListenedAUC = auc(testData, bAllArtistIDs, predictMostListened(trainData))
println(mostListenedAUC)

/*
val evaluations =
      for (rank     <- Seq(5,  20);
           regParam <- Seq(1.0, 0.1,0.001);
           alpha    <- Seq(1.0, 40.0)) //这里表示为3层嵌套for循环，rank循环里嵌套着regParam循环，里面再嵌套着alpha循环
      yield {
        val model = new ALS().
          setSeed(Random.nextLong()).
          setImplicitPrefs(true).
          setRank(rank).setRegParam(regParam).
          setAlpha(alpha).setMaxIter(5).
          setUserCol("user").setItemCol("artist").
          setRatingCol("count").setPredictionCol("prediction").
          fit(trainData1)

        val auc_score = auc(testData, bAllArtistIDs, model.transform)

        model.userFactors.unpersist() // 立即释放模型占用的资源
        model.itemFactors.unpersist()

        (auc_score, (rank, regParam, alpha))
      }

    evaluations.sorted.reverse.foreach(println) // 按第一个值（AUC）的降序排列并输出
*/

//将数据写入hdfs

//modelnew.save("/usr/local/spark/Model/modelnew")
//artistByID.write.format("csv").save("/usr/local/spark/Model/artistByID")
//trainData.write.format("csv").save("/usr/local/spark/Model/trainData")



